import nltk
from nltk.corpus import wordnet as wn
from collections import defaultdict
import networkx as nx
import matplotlib.pyplot as plt
import os, sys, re , time
# PATHS
# Resolved relative to this file so the script works from any working directory.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
WORDNET_DIR = os.path.join(PROJECT_ROOT, '../../../data/wordnet')
FIGURE_PATH = os.path.join(PROJECT_ROOT, '../../../output/wordnet_graph_es.svg')

# Keep the WordNet and Open Multilingual WordNet (OMW) data in a
# project-local directory. nltk's downloader skips packages that are already
# up to date, so calling it on every run only fetches data when it is missing.
os.makedirs(WORDNET_DIR, exist_ok=True)
nltk.download('wordnet', download_dir=WORDNET_DIR)
nltk.download('omw-1.4', download_dir=WORDNET_DIR)

# Register the directory on *every* run, not only the run that created it.
# Otherwise later runs would never search the project-local copy. Insert it
# first so it takes precedence over any other installed copy.
if WORDNET_DIR not in nltk.data.path:
    nltk.data.path.insert(0, WORDNET_DIR)

# The language is not set globally here. WordNet lookups select it per call
# through their `lang` argument (e.g. lang='spa' for Spanish).

# Function to get related words and assign weights
def get_related_words(seed_words, max_depth=2, lang='spa', decay=0.5, lookup=None):
    """Expand *seed_words* through shared WordNet synsets and weight each word.

    Every seed gets weight 1. From each seed, a breadth-first walk follows
    related lemmas (words sharing a synset with the current word) for up to
    ``max_depth`` levels. A word first reached at level ``d`` (1-based) gets
    weight ``decay ** d``.

    A word keeps the weight from the first time it is reached. Seeds are
    therefore never overwritten, and a word reachable from several seeds is
    weighted only once.

    Args:
        seed_words: Iterable of words to start from (consumed once).
        max_depth: Number of expansion levels per seed. ``0`` or less
            returns only the seeds.
        lang: WordNet/OMW language code used by the default lookup.
        decay: Factor applied to the weight at each level (0.5 halves it).
        lookup: Optional callable ``word -> iterable of related lemma names``.
            Defaults to the lemma names (in ``lang``) of every synset returned
            by ``wn.synsets(word, lang=lang)``. Pass a custom callable to use
            another lexicon or to run without WordNet data.

    Returns:
        A ``defaultdict(float)`` mapping each word to its weight. Unknown
        keys read as 0.
    """
    if lookup is None:
        def lookup(word):
            return (lemma.name()
                    for synset in wn.synsets(word, lang=lang)
                    for lemma in synset.lemmas(lang))

    # Materialize once: the seeds are iterated twice below, which would
    # silently skip the expansion if a generator were passed.
    seeds = list(seed_words)

    weighted_dict = defaultdict(float)
    for seed_word in seeds:
        weighted_dict[seed_word] = 1.0

    for seed_word in seeds:
        frontier = {seed_word}
        weight = 1.0
        for _ in range(max_depth):
            weight *= decay  # every word discovered at this level shares it
            next_frontier = set()
            for word in frontier:
                for lemma_name in lookup(word):
                    # First discovery wins; already-weighted words (seeds or
                    # words found earlier) are neither re-weighted nor re-expanded.
                    if lemma_name not in weighted_dict:
                        weighted_dict[lemma_name] = weight
                        next_frontier.add(lemma_name)
            if not next_frontier:
                break  # nothing new to expand from this seed
            frontier = next_frontier

    return weighted_dict

# Spanish seed terms; each one enters the weighted dictionary with weight 1
# and is the starting point of a WordNet expansion.
seed_words = [
    "violencia", "asesinato", "homicidio", "tiroteo", "ataque",
    "enfrentamiento", "balacera", "secuestro", "narcotráfico",
    "delincuencia",
]

# Construct the weighted dictionary
weighted_dict = get_related_words(seed_words)

# Build an undirected graph with one node per word. Each node carries its
# weight and whether it is a seed.
G = nx.Graph()

seed_set = set(seed_words)  # O(1) membership test per node
for word, weight in weighted_dict.items():
    G.add_node(word, weight=weight, seed=word in seed_set)

# Edges run only from seeds to words sharing one of their synsets. Words
# first reached at deeper expansion levels therefore get an edge only if they
# also share a synset with some seed; otherwise they are drawn isolated.
for word in seed_words:
    for synset in wn.synsets(word, lang='spa'):
        for lemma in synset.lemmas('spa'):
            lemma_name = lemma.name()
            # A seed is a lemma of its own synsets; skip it so no seed gets a
            # self-loop edge.
            if lemma_name != word and lemma_name in weighted_dict:
                G.add_edge(word, lemma_name, weight=weighted_dict[lemma_name])

# Visualize the graph using the Kamada-Kawai layout for node positions.
pos = nx.kamada_kawai_layout(G)

# Node size scales with weight; seeds are drawn in a different colour.
node_sizes = [G.nodes[node]['weight'] * 1000 for node in G]
node_colors = ['lightcoral' if G.nodes[node]['seed'] else 'skyblue' for node in G]

plt.figure(figsize=(12, 8))
nx.draw(G, pos, with_labels=True, node_size=node_sizes, node_color=node_colors,
        font_size=8, font_weight='bold', edge_color='gray', width=1.0)
edge_labels = {(u, v): f'{d:.2f}' for u, v, d in G.edges(data='weight')}
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=6)

plt.title("WordNet Graph of Related Spanish Words", fontsize=14)

# Save the graph as an SVG file. savefig does not create missing directories,
# so ensure the output folder exists first.
os.makedirs(os.path.dirname(FIGURE_PATH), exist_ok=True)
plt.savefig(FIGURE_PATH, format='svg')

plt.show()
